from blockprocessing import BlockProcessing
from utils import save, eval_psnr
import matplotlib.pyplot as plt
import numpy as np
import h5py
from argparse import ArgumentParser


def parse_args():
    """Parse the command-line arguments of this script.

    Returns
    -------
    argparse.Namespace
        Namespace with attributes ``train_out``, ``image``, ``patch_size``
        (an ``int``) and ``out_file_stem``.
    """
    p = ArgumentParser()
    p.add_argument("--train-out", required=True, help="output of TVAE training")
    p.add_argument(
        "--image",
        required=True,
        # The script reads the dataset named 'original' from this file.
        help="original image as 'original' dataset in an HDF5 file",
    )
    p.add_argument(
        "--patch-size",
        required=True,
        type=int,  # validate at the CLI boundary instead of passing a raw string
        help="patch size, e.g. 12 for 12x12 patches",
    )
    p.add_argument(
        "--out-file-stem",
        required=True,
        help="stem for output files (must include parent directories if any)",
    )
    return p.parse_args()


def depatchify(reco_patches, image):
    """Reassemble reconstructed patches into an image and evaluate its PSNR.

    Parameters
    ----------
    reco_patches : numpy.ndarray
        2-D array holding one flattened square patch per row. The number of
        columns must be a positive perfect square (``patch_size ** 2``). It is
        transposed before being written into ``BlockProcessing.Y``.
    image : numpy.ndarray
        Reference image. It is not modified: ``BlockProcessing`` modifies its
        input, so it receives a copy.

    Returns
    -------
    tuple
        ``(reconstruction, psnr)``: the reassembled image (``bp.I``) and the
        value returned by ``eval_psnr(image, reconstruction)``.

    Raises
    ------
    ValueError
        If ``reco_patches`` is not 2-D, or its column count is not a positive
        perfect square.
    """
    reco_patches = np.asarray(reco_patches)
    if reco_patches.ndim != 2:
        raise ValueError(
            f"reco_patches must be 2-D (patches x pixels), got shape {reco_patches.shape}"
        )

    patch_dim = reco_patches.shape[1]
    # Round instead of truncating, then verify: a non-square patch dimension
    # must fail loudly rather than silently select a wrong patch size.
    patch_size = int(round(np.sqrt(patch_dim)))
    if patch_dim == 0 or patch_size * patch_size != patch_dim:
        raise ValueError(
            f"patch dimension {patch_dim} is not a positive perfect square"
        )

    mask = np.zeros_like(image)  # 0 -> to reconstruct
    img_copy = image.copy()  # bp modifies the input

    bp = BlockProcessing(
        img_copy,
        mask=mask,
        patchheight=patch_size,
        patchwidth=patch_size,
        pp_params={"pp_type": None, "sf_type": "gauss_"},
    )
    bp.im2bl()
    bp.Y[:] = reco_patches.T
    bp.bl2im()
    psnr = eval_psnr(image, bp.I)
    return bp.I, psnr


def main():
    """Depatchify a training reconstruction and write HDF5 and PNG results.

    Reads the reconstructed patches from the ``train_reconstruction`` dataset
    of ``--train-out`` and the reference image from ``--image``. It then
    reassembles the image with ``depatchify`` and writes two files:
    ``<stem>.h5``, holding ``data`` (the image) and ``psnr``, and
    ``<stem>.png``, a grayscale plot titled with the PSNR.

    Raises
    ------
    KeyError
        If the ``--image`` file has neither an ``original`` nor a ``data``
        dataset.
    """
    import warnings

    args = parse_args()
    out_file_stem = args.out_file_stem

    # Use a distinct name for the file handle so the path is not shadowed.
    with h5py.File(args.train_out, "r") as train_file:
        patches = train_file["train_reconstruction"][...]

    with h5py.File(args.image, "r") as image_file:
        # Accept either an 'original' or a 'data' dataset, preferring
        # 'original'; fail with a clear message if neither is present.
        for key in ("original", "data"):
            if key in image_file:
                orig_img = image_file[key][...]
                break
        else:
            raise KeyError(
                f"{args.image} has neither an 'original' nor a 'data' dataset"
            )

    # depatchify infers the patch size from the data, so --patch-size only
    # serves as a consistency check here; a mismatch does not change output.
    expected_dim = int(args.patch_size) ** 2
    if patches.ndim == 2 and patches.shape[1] != expected_dim:
        warnings.warn(
            f"--patch-size={args.patch_size} implies {expected_dim} values per "
            f"patch, but the reconstruction has {patches.shape[1]}; "
            "using the reconstruction's patch dimension"
        )

    img, psnr = depatchify(patches, orig_img)

    save(f"{out_file_stem}.h5", {"data": img, "psnr": psnr})

    # Use an explicit figure and close it so repeated calls do not leak figures.
    fig, ax = plt.subplots()
    ax.imshow(img, cmap="gray")
    ax.axis("off")
    ax.set_title(f"Denoised image (PSNR={psnr:.2f})")
    fig.savefig(f"{out_file_stem}.png")
    plt.close(fig)


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
